import typing
from typing import *
from torch_geometric.typing import *

import torch
from torch import Tensor
from torch_geometric.nn.conv.message_passing import *
from {{module}} import *


# Typed bundle of the named inputs given to `propagate()`:
# one field per entry of `prop_types`, annotated with the mapped type.
class Propagate_{{uid}}(NamedTuple):
{%- for field in prop_types %}
    {{field}}: {{prop_types[field]}}
{%- endfor %}

{% if edge_updater_types|length > 0 %}
# Typed bundle of the named inputs given to `edge_updater()`:
# one field per entry of `edge_updater_types`, annotated with the mapped type.
class EdgeUpdater_{{uid}}(NamedTuple):
{%- for field in edge_updater_types %}
    {{field}}: {{edge_updater_types[field]}}
{%- endfor %}
{% endif %}

# Return type of `_collect()`: one field per entry of `collect_types`,
# annotated with the mapped type.
class Collect_{{uid}}(NamedTuple):
{%- for field in collect_types %}
    {{field}}: {{collect_types[field]}}
{%- endfor %}

{% if edge_updater_types|length > 0 %}
# Return type of `_collect_edge()`: one field per entry of
# `edge_collect_types`, annotated with the mapped type.
class EdgeUpdater_Collect_{{uid}}(NamedTuple):
{%- for field in edge_collect_types %}
    {{field}}: {{edge_collect_types[field]}}
{%- endfor %}
{% endif %}

class {{cls_name}}({{parent_cls_name}}):

    def _collect(
        self,
        edge_def: Union[Tensor, SparseTensor],
        size: List[Optional[int]],
        kwargs: Propagate_{{uid}},
    ) -> Collect_{{uid}}:
        # Builds the `Collect_*` tuple that `propagate` reads its message,
        # aggregate and update arguments from: resolves every entry of
        # `user_args` from `kwargs`, unpacks `edge_def` into index tensors and
        # derives the size-related values.
        #
        # edge_def: connectivity, either a Tensor (indexed by row) or a
        #           SparseTensor (unpacked with `coo()` and `csr()`).
        # size:     list whose entries 0 and 1 are read here; it is also handed
        #           to `self._set_size` together with every tensor input.
        # kwargs:   the named inputs that were given to `propagate`.

        # Scalar placeholder that gives non-Optional tensor locals an initial
        # value. Every local seeded with it is reassigned below whenever
        # `edge_def` is a Tensor or a SparseTensor.
        init = torch.tensor(0.)
        # Row `i` of a Tensor `edge_def` feeds the `_i` variables and row `j`
        # the `_j` variables; which row is which depends on `self.flow`.
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for arg in user_args %}
{%- if arg[-2:] not in ['_i', '_j'] %}
        {{arg}} = kwargs.{{arg}}
{%- else %}
        # Suffixed argument: derived from the input named without the suffix.
        {{arg}}: {{collect_types[arg]}} = {% if collect_types[arg][:8] == 'Optional' %}None{% else %}init{% endif %}
        data = kwargs.{{arg[:-2]}}
        if isinstance(data, (tuple, list)):
            # A pair: one element supplies the value, the other is only
            # reported to `self._set_size`.
            assert len(data) == 2
{%- if arg[-2:] == '_j' %}
            tmp = data[1]
            if isinstance(tmp, Tensor):
                self._set_size(size, 1, tmp)
            {{arg}} = data[0]
{%- else %}
            tmp = data[0]
            if isinstance(tmp, Tensor):
                self._set_size(size, 0, tmp)
            {{arg}} = data[1]
{%- endif %}
        else:
            {{arg}} = data
        if isinstance({{arg}}, Tensor):
            # Report the selected tensor to `self._set_size`, then pass it
            # through `self._lift` with `edge_def` and the matching row index.
            self._set_size(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
            {{arg}} = self._lift({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

        # Unpack `edge_def`. Only the locals that belong to the matching
        # representation are filled; `edge_index`/`adj_t`/`ptr` otherwise
        # keep their None default.
        edge_index: Optional[Tensor] = None
        adj_t: Optional[SparseTensor] = None
        edge_index_i: torch.Tensor = init
        edge_index_j: torch.Tensor = init
        ptr: Optional[Tensor] = None
        if isinstance(edge_def, Tensor):
            edge_index = edge_def
            edge_index_i = edge_def[i]
            edge_index_j = edge_def[j]
        elif isinstance(edge_def, SparseTensor):
            adj_t = edge_def
            edge_index_i, edge_index_j, value = edge_def.coo()
            ptr, _, _ = edge_def.csr()

            # The value returned by `coo()` stands in for any of the
            # edge-level inputs below that was passed as None.
            {% if 'edge_weight' in collect_types.keys() %}
            if edge_weight is None:
                edge_weight = value
            {% endif %}
            {% if 'edge_attr' in collect_types.keys() %}
            if edge_attr is None:
                edge_attr = value
            {% endif %}
            {% if 'edge_type' in collect_types.keys() %}
            if edge_type is None:
                edge_type = value
            {% endif %}

        # An assert is rendered only for a field whose declared type does not
        # start with 'Optional'.
        {% if collect_types.get('edge_weight', 'Optional')[:8] != 'Optional' %}assert edge_weight is not None{% endif %}
        {% if collect_types.get('edge_attr', 'Optional')[:8] != 'Optional' %}assert edge_attr is not None{% endif %}
        {% if collect_types.get('edge_type', 'Optional')[:8] != 'Optional' %}assert edge_type is not None{% endif %}

        index = edge_index_i
        # Each size falls back to the other entry of `size` when its own
        # entry is None.
        size_i = size[1] if size[1] is not None else size[0]
        size_j = size[0] if size[0] is not None else size[1]
        dim_size = size_i

        # Every field of the result is read from the local of the same name.
        return Collect_{{uid}}({% for k in collect_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

{% if edge_updater_types|length > 0 %}
    def _collect_edge(
        self,
        edge_def: Union[Tensor, SparseTensor],
        size: List[Optional[int]],
        kwargs: EdgeUpdater_{{uid}},
    ) -> EdgeUpdater_Collect_{{uid}}:
        # Builds the `EdgeUpdater_Collect_*` tuple that `edge_updater` reads
        # the `edge_update` arguments from: resolves every entry of
        # `edge_user_args` from `kwargs`, unpacks `edge_def` into index
        # tensors and derives the size-related values.
        #
        # edge_def: connectivity, either a Tensor (indexed by row) or a
        #           SparseTensor (unpacked with `coo()` and `csr()`).
        # size:     list whose entries 0 and 1 are read here; it is also handed
        #           to `self._set_size` together with every tensor input.
        # kwargs:   the named inputs that were given to `edge_updater`.

        # Scalar placeholder that gives non-Optional tensor locals an initial
        # value. Every local seeded with it is reassigned below whenever
        # `edge_def` is a Tensor or a SparseTensor.
        init = torch.tensor(0.)
        # Row `i` of a Tensor `edge_def` feeds the `_i` variables and row `j`
        # the `_j` variables; which row is which depends on `self.flow`.
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for arg in edge_user_args %}
{%- if arg[-2:] not in ['_i', '_j'] %}
        {{arg}} = kwargs.{{arg}}
{%- else %}
        # Suffixed argument: derived from the input named without the suffix.
        {{arg}}: {{edge_collect_types[arg]}} = {% if edge_collect_types[arg][:8] == 'Optional' %}None{% else %}init{% endif %}
        data = kwargs.{{arg[:-2]}}
        if isinstance(data, (tuple, list)):
            # A pair: one element supplies the value, the other is only
            # reported to `self._set_size`.
            assert len(data) == 2
{%- if arg[-2:] == '_j' %}
            tmp = data[1]
            if isinstance(tmp, Tensor):
                self._set_size(size, 1, tmp)
            {{arg}} = data[0]
{%- else %}
            tmp = data[0]
            if isinstance(tmp, Tensor):
                self._set_size(size, 0, tmp)
            {{arg}} = data[1]
{%- endif %}
        else:
            {{arg}} = data
        if isinstance({{arg}}, Tensor):
            # Report the selected tensor to `self._set_size`, then pass it
            # through `self._lift` with `edge_def` and the matching row index.
            self._set_size(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
            {{arg}} = self._lift({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

        # Unpack `edge_def`. Only the locals that belong to the matching
        # representation are filled; `edge_index`/`adj_t`/`ptr` otherwise
        # keep their None default.
        edge_index: Optional[Tensor] = None
        adj_t: Optional[SparseTensor] = None
        edge_index_i: torch.Tensor = init
        edge_index_j: torch.Tensor = init
        ptr: Optional[Tensor] = None
        if isinstance(edge_def, Tensor):
            edge_index = edge_def
            edge_index_i = edge_def[i]
            edge_index_j = edge_def[j]
        elif isinstance(edge_def, SparseTensor):
            adj_t = edge_def
            edge_index_i, edge_index_j, value = edge_def.coo()
            ptr, _, _ = edge_def.csr()

            # The value returned by `coo()` stands in for any of the
            # edge-level inputs below that was passed as None.
            {% if 'edge_weight' in edge_collect_types.keys() %}
            if edge_weight is None:
                edge_weight = value
            {% endif %}
            {% if 'edge_attr' in edge_collect_types.keys() %}
            if edge_attr is None:
                edge_attr = value
            {% endif %}
            {% if 'edge_type' in edge_collect_types.keys() %}
            if edge_type is None:
                edge_type = value
            {% endif %}

        # An assert is rendered only for a field whose declared type does not
        # start with 'Optional'.
        {% if edge_collect_types.get('edge_weight', 'Optional')[:8] != 'Optional' %}assert edge_weight is not None{% endif %}
        {% if edge_collect_types.get('edge_attr', 'Optional')[:8] != 'Optional' %}assert edge_attr is not None{% endif %}
        {% if edge_collect_types.get('edge_type', 'Optional')[:8] != 'Optional' %}assert edge_type is not None{% endif %}

        index = edge_index_i
        # Each size falls back to the other entry of `size` when its own
        # entry is None.
        size_i = size[1] if size[1] is not None else size[0]
        size_j = size[0] if size[0] is not None else size[1]
        dim_size = size_i

        # Every field of the result is read from the local of the same name.
        return EdgeUpdater_Collect_{{uid}}({% for k in edge_collect_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})


{% endif %}

    def propagate(
        self,
        edge_index: Union[Tensor, SparseTensor],
        {% for name, type_hint in prop_types.items() %}
        {{ name }}: {{ type_hint }},
        {% endfor %}
        size: Size=None,
    ) -> {{ prop_return_type }}:
        # Runs one propagation step with an explicit, typed parameter for
        # every entry of `prop_types`: message -> aggregate -> update.
        # Returns whatever `self.update` returns.

        # Size list produced by `self._check_input`; passed on to `_collect`.
        the_size = self._check_input(edge_index, size)
        # Bundle the named inputs into one typed tuple.
        in_kwargs = Propagate_{{uid}}({% for k in prop_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

        {% if fuse %}
        # Rendered only when `fuse` is set: a SparseTensor input goes through
        # `message_and_aggregate` and skips `_collect`, so the arguments are
        # read straight from the inputs.
        if isinstance(edge_index, SparseTensor):
            out = self.message_and_aggregate(edge_index{% for k in msg_and_aggr_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
            return self.update(out{% for k in update_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
        {% endif %}

        # General path: every argument of the three stages is read from the
        # tuple returned by `_collect`.
        kwargs = self._collect(edge_index, the_size, in_kwargs)
        out = self.message({% for k in msg_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
        out = self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %})
        return self.update(out{% for k in update_args %}, {{k}}=kwargs.{{k}}{% endfor %})

{% if edge_updater_types|length > 0 %}
    def edge_updater(
        self,
        edge_index: Union[Tensor, SparseTensor],
        {% for name, type_hint in edge_updater_types.items() %}
        {{ name }}: {{ type_hint }},
        {% endfor %}
    ) -> {{ edge_updater_return_type }}:
        # Rendered when `edge_updater_types` is non-empty: takes an explicit,
        # typed parameter for every entry and returns the result of
        # `self.edge_update`.

        # No size is accepted here, so `_check_input` is called with None.
        the_size = self._check_input(edge_index, size=None)
        # Bundle the named inputs into one typed tuple.
        in_kwargs = EdgeUpdater_{{uid}}({% for k in edge_updater_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})
        # Every `edge_update` argument is read from the tuple returned by
        # `_collect_edge`.
        kwargs = self._collect_edge(edge_index, the_size, in_kwargs)
        return self.edge_update({% for k in edge_update_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
{% else %}
    def edge_updater(self):
        # Rendered when `edge_updater_types` is empty: a no-argument stub
        # that does nothing.
        pass
{% endif %}

{{ forward_repr }}
